{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Climate event detection task\n",
    "Sandbox for preprocessing and first learning test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"  # change to chosen GPU to use, nothing if work on CPU\n",
    "\n",
    "# on CLI,  export LD_LIBRARY_PATH=/usr/local/cuda-9.0/extras/CUPTI/lib64:$LD_LIBRARY_PATH\n",
    "\n",
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import healpy as hp\n",
    "import pandas as pd\n",
    "\n",
    "from tqdm import tqdm\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import cartopy.crs as ccrs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepsphere import models, experiment_helper, plot, utils\n",
    "from deepsphere.data import LabeledDatasetWithNoise, LabeledDataset\n",
    "import hyperparameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "List of parameters:\n",
    "* TMQ: Total (vertically integrated) precipitable water\n",
    "* U850: Zonal wind at 850 mbar pressure surface\n",
    "* V850: Meridional wind at 850 mbar pressure surface\n",
    "* UBOT: Lowest model level zonal wind\n",
    "* VBOT: Lowest model level meridional wind\n",
    "* QREFHT: Reference height humidity\n",
    "* PS: Surface pressure\n",
    "* PSL: Sea level pressure\n",
    "* T200: Temperature at 200 mbar pressure surface\n",
    "* T500: Temperature at 500 mbar pressure surface\n",
    "* PRECT: Total (convective and large-scale) precipitation rate (liq + ice)\n",
    "* TS: Surface temperature (radiative)\n",
    "* Z100: Geopotential Z at 100 mbar pressure surface\n",
    "* Z200: Geopotential Z at 200 mbar pressure surface\n",
    "* ZBOT: Lowest model level height\n",
    "\n",
    "Note: the data arrays used below have 16 channels, while only 15 fields are listed here; the list is probably missing one entry (to be confirmed against the dataset documentation).\n",
    "\n",
    "Resolution: 768 x 1152 equirectangular grid (25 km at the equator).\n",
    "\n",
    "The labels are 0 for the background class, 1 for tropical cyclone, and 2 for atmospheric river."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '../../data/Climate/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "year, month, day, hour, run = 2106, 1, 1, 0, 1\n",
    "datapath = '../../data/Climate/data_5_all/data-{}-{:0>2d}-{:0>2d}-{:0>2d}-{}-mesh.npz'.format(year, month, day, hour, run)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plop = np.load(datapath)\n",
    "data = plop[\"data\"]\n",
    "labels = plop[\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open read-only: an explicit mode avoids h5py's legacy append default, which\n",
    "# could silently create an empty file when the path is wrong.\n",
    "stats = h5py.File('../../data/Climate/stats.h5', 'r')\n",
    "stats = stats['climate'][\"stats\"] # (16 X 4) (mean, max, min, std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "year, month, day, hour, run = 2106, 1, 1, 0, 1\n",
    "datapath = '../../data/Climate/data-{}-{:0>2d}-{:0>2d}-{:0>2d}-{}.h5'.format(year, month, day, hour, run)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open read-only. `data` and `labels` are lazy h5py datasets backed by `h5f`,\n",
    "# so the file has to stay open while the cells below index into them.\n",
    "h5f = h5py.File(datapath, 'r')\n",
    "data = h5f['climate'][\"data\"] # (16,768,1152) h5py dataset\n",
    "labels = h5f['climate'][\"labels\"] # (768,1152) h5py dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lon_ = np.arange(1152)/1152*360\n",
    "lat_ = np.arange(768)/768*180-90\n",
    "lon, lat = np.meshgrid(lon_, lat_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepsphere.utils import icosahedron_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = icosahedron_graph(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icolong, icolat = np.rad2deg(g.lon), np.rad2deg(g.lat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(4, 4))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.Orthographic(90, 0))\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "\n",
    "# Background field: channel 0 of the equirectangular sample.\n",
    "scat1 = plt.scatter(lon, lat, s=1, rasterized=True,\n",
    "            c=data[0,:,:], cmap=plt.get_cmap('RdYlBu_r'), alpha=1, transform=ccrs.PlateCarree())\n",
    "\n",
    "# Label convention stated in the parameter list of this notebook:\n",
    "# 0 = background, 1 = tropical cyclone (TC), 2 = atmospheric river (AR).\n",
    "TC = labels[:,:]==1\n",
    "AR = labels[:,:]==2\n",
    "scat2 = ax.scatter(lon[AR], lat[AR], s=0.5, color='c', label='AR', transform=ccrs.PlateCarree())\n",
    "scat3 = ax.scatter(lon[TC], lat[TC], s=0.5, color='m', label='TC', transform=ccrs.PlateCarree())\n",
    "ax.legend(markerscale=5, fontsize=10, loc=1, frameon=False, ncol=1, bbox_to_anchor=(0.1, 0.18))\n",
    "ticks = range(np.min(data[0,:,:]).astype(int), np.max(data[0,:,:]).astype(int), 20)\n",
    "cb = plt.colorbar(scat1, ax=ax, orientation=\"horizontal\",anchor=(1.0,0.0), shrink=0.7, pad=0.05, ticks=ticks)\n",
    "cb.ax.tick_params(labelsize=10)\n",
    "# NOTE(review): labels skip ticks[0], presumably because the first tick falls below the\n",
    "# data minimum and is not drawn -- TODO confirm.\n",
    "cb.ax.set_xticklabels([f'${t}mm$' for t in ticks[1:]])\n",
    "\n",
    "ax.text(0, 7e6, f'HAPPI20 Climate, TMQ, {year}-{month:02d}-{day:02d}-{hour:02d}-{run}', horizontalalignment='center')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n",
    "\n",
    "ax.set_global()\n",
    "# ax.stock_img()\n",
    "ax.coastlines()\n",
    "\n",
    "plt.scatter(lon, lat, s=1,\n",
    "            c=data[0,:,:], cmap=plt.get_cmap('RdYlBu_r'), alpha=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n",
    "\n",
    "ax.set_global()\n",
    "# ax.stock_img()\n",
    "ax.coastlines()\n",
    "\n",
    "plt.scatter(lon, lat, s=1,\n",
    "            c=labels[:,:], cmap=plt.get_cmap('RdYlBu_r'), alpha=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `data` was rebound to the equirectangular HDF5 sample above, so it no longer matches the\n",
    "# icosahedral vertices. Reload the icosahedral sample here so that this cell also works\n",
    "# under Restart & Run All.\n",
    "ico_path = '../../data/Climate/data_5_all/data-{}-{:0>2d}-{:0>2d}-{:0>2d}-{}-mesh.npz'.format(year, month, day, hour, run)\n",
    "with np.load(ico_path) as ico_sample:\n",
    "    ico_data = ico_sample[\"data\"]\n",
    "\n",
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n",
    "\n",
    "ax.set_global()\n",
    "ax.coastlines()\n",
    "\n",
    "# Channel 0 of the sample, one value per icosahedral vertex.\n",
    "plt.scatter(icolong, icolat, s=20,\n",
    "            c=ico_data[0,:], cmap=plt.get_cmap('RdYlBu_r'), alpha=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `labels` was rebound to the equirectangular HDF5 labels above, so it no longer matches the\n",
    "# icosahedral vertices. Reload the icosahedral sample here so that this cell also works\n",
    "# under Restart & Run All.\n",
    "ico_path = '../../data/Climate/data_5_all/data-{}-{:0>2d}-{:0>2d}-{:0>2d}-{}-mesh.npz'.format(year, month, day, hour, run)\n",
    "with np.load(ico_path) as ico_sample:\n",
    "    ico_labels = ico_sample[\"labels\"]\n",
    "\n",
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n",
    "\n",
    "ax.set_global()\n",
    "ax.coastlines()\n",
    "\n",
    "# NOTE(review): row 0 of the label array is plotted; the loading cells further down treat\n",
    "# the first axis as a per-class axis (argmax over it), so this is presumably the\n",
    "# background indicator only -- confirm this is what should be shown.\n",
    "plt.scatter(icolong, icolat, s=20,\n",
    "            c=ico_labels[0,:], cmap=plt.get_cmap('RdYlBu_r'), alpha=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.interpolate import griddata, RectBivariateSpline, RegularGridInterpolator, LinearNDInterpolator, interp2d, NearestNDInterpolator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.interpolate import griddata\n",
    "Nside = 32\n",
    "pix = np.arange(12*Nside**2)\n",
    "# (lon, lat) in degrees of every HEALPix pixel (NESTED ordering), stacked to shape (npix, 2).\n",
    "coords_hp = hp.pix2ang(Nside, pix, nest=True, lonlat=True)\n",
    "coords_hp = np.asarray(coords_hp).T\n",
    "# (lon, lat) of every equirectangular grid point, flattened in the same row-major order as\n",
    "# `data[c].flatten()`. The previous 3-D unit-vector variant was overwritten immediately and\n",
    "# has been dropped.\n",
    "coords_map = np.stack([lon, lat], axis=-1).reshape((-1, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "map_hp1 = griddata(coords_map, data[0].flatten(), coords_hp, 'linear')\n",
    "print(\"time taken:\", time.time()-t)\n",
    "\n",
    "t = time.time()\n",
    "f = RegularGridInterpolator((lon_, lat_), data[0].T)\n",
    "map_hp3 = f(coords_hp)\n",
    "print(\"time taken:\", time.time()-t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n",
    "\n",
    "ax.set_global()\n",
    "# ax.stock_img()\n",
    "ax.coastlines()\n",
    "\n",
    "plt.scatter(coords_hp[:,0], coords_hp[:,1], s=10,\n",
    "            c=map_hp3, cmap=plt.get_cmap('RdYlBu_r'), alpha=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coords_map.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f = NearestNDInterpolator(coords_map, labels[:].flatten(), rescale=False)\n",
    "new_labels = f(coords_hp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n",
    "\n",
    "ax.set_global()\n",
    "# ax.stock_img()\n",
    "ax.coastlines()\n",
    "\n",
    "plt.scatter(coords_hp[:,0], coords_hp[:,1], s=10,\n",
    "            c=new_labels, cmap=plt.get_cmap('RdYlBu_r'), alpha=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from glob import glob\n",
    "year = 2106\n",
    "# Use a dedicated name for the glob pattern: `path` is the dataset root and later cells\n",
    "# build file names from it.\n",
    "h5_pattern = '../../data/Climate/data-{}-01*.h5'.format(year)\n",
    "# glob returns files in arbitrary order; sort so that sample indices are reproducible.\n",
    "files = sorted(glob(h5_pattern))\n",
    "datas = np.zeros((len(files),16,768,1152))\n",
    "labels = np.zeros((len(files),768,1152))\n",
    "for i, file in enumerate(files):\n",
    "    # Read-only, and closed as soon as the arrays have been copied out.\n",
    "    with h5py.File(file, 'r') as h5_sample:\n",
    "        datas[i] = h5_sample['climate']['data']\n",
    "        labels[i] = h5_sample['climate']['labels']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from glob import glob\n",
    "year = 2106\n",
    "# Use a dedicated name for the glob pattern: `path` is the dataset root and later cells\n",
    "# build file names from it.\n",
    "npz_pattern = '../../data/Climate/data_5_all/data-{}-*.npz'.format(year)\n",
    "# glob returns files in arbitrary order; sort so that the index-based train/validation\n",
    "# split below is reproducible.\n",
    "files = sorted(glob(npz_pattern))\n",
    "datas = np.zeros((len(files),16,10242))\n",
    "labels = np.zeros((len(files),3,10242))\n",
    "for i, file in enumerate(files):\n",
    "    # Close each archive as soon as its arrays have been copied out.\n",
    "    with np.load(file) as sample:\n",
    "        datas[i] = sample['data']\n",
    "        labels[i] = sample['labels']\n",
    "# (N, 3, V) per-class scores -> (N, V) class indices.\n",
    "labels = np.argmax(labels, axis=1)\n",
    "# (N, 16, V) -> (N, V, 16): features last.\n",
    "datas = np.transpose(datas, axes=(0,2,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first `limit` samples are used for training, the remainder for validation.\n",
    "limit = 6000\n",
    "x_train, x_val = datas[:limit], datas[limit:]\n",
    "labels_train, labels_val = labels[:limit], labels[limit:]\n",
    "\n",
    "training = LabeledDataset(x_train, labels_train)\n",
    "validation = LabeledDataset(x_val, labels_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Jiang separation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "precomp_mean = [26.160023, 0.98314494, 0.116573125, -0.45998842, 0.1930554, 0.010749293, 98356.03, 100982.02, 216.13145, 258.9456, 3.765611e-08, 288.82578, 288.03925, 342.4827, 12031.449, 63.435772]\n",
    "precomp_std =  [17.04294, 8.164175, 5.6868863, 6.4967732, 5.4465833, 0.006383436, 7778.5957, 3846.1863, 9.791707, 14.35133, 1.8771327e-07, 19.866386, 19.094095, 624.22406, 679.5602, 4.2283397]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rotmat = np.array([[np.cos(np.pi/4),np.sin(np.pi/4)],\n",
    "                    [-np.sin(np.pi/4),np.cos(np.pi/4)]])\n",
    "# change to magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset root. Re-set here because earlier cells rebind `path` to glob patterns, and both\n",
    "# this cell and the dataset cells below build file names from it.\n",
    "path = '../../data/Climate/'\n",
    "\n",
    "data = {}\n",
    "# NOTE(review): only 'val' is loaded, but cells below read data['train'] and data['test'];\n",
    "# add those partitions here when they are needed.\n",
    "for partition in ['val']:\n",
    "    # <partition>.txt lists one sample file name per line.\n",
    "    with open(path+'data_5_all/'+partition+\".txt\", \"r\") as listing:\n",
    "        lines = listing.readlines()\n",
    "    flist = [os.path.join(path, 'data_5_all', l.replace('\\n', '')) for l in lines]\n",
    "    part = {'data': np.zeros((len(flist),10242,16)),\n",
    "            'labels': np.zeros((len(flist),10242))}\n",
    "    data[partition] = part\n",
    "    for i, fname in enumerate(flist):\n",
    "        with np.load(fname) as file:\n",
    "            x = part['data'][i]  # view: writes below land in part['data']\n",
    "            # Standardise every channel with the precomputed statistics.\n",
    "            x[:] = (file['data'].T - precomp_mean) / precomp_std\n",
    "            # Replace the wind pairs in channels (1, 2) and (3, 4) by their angle.\n",
    "            # NOTE(review): the second channel of each pair is a copy of the angle; a nearby\n",
    "            # comment says \"change to magnitude\" -- confirm whether a magnitude was intended.\n",
    "            x[:, 1] = np.arctan2(x[:, 1], x[:, 2])\n",
    "            x[:, 2] = x[:, 1]\n",
    "            x[:, 3] = np.arctan2(x[:, 3], x[:, 4])\n",
    "            x[:, 4] = x[:, 3]\n",
    "            # `np.int` was removed from NumPy; the builtin `int` is the equivalent dtype.\n",
    "            part['labels'][i] = np.argmax(file['labels'].astype(int), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train = data['train']['data']\n",
    "labels_train = data['train']['labels']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_test = data['test']['data']\n",
    "labels_test = data['test']['labels']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training = LabeledDataset(data['train']['data'], data['train']['labels'])\n",
    "validation = LabeledDataset(data['val']['data'], data['val']['labels'])\n",
    "# test = LabeledDataset(data['test']['data'], data['test']['labels'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TF dataset with Jiang separation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ClimateDataLoader import IcosahedronDataset, EquiangularDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training = IcosahedronDataset(path+'data_5_all/', 'train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "validation = IcosahedronDataset(path+'data_5_all/', 'val')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_train = training.get_tf_dataset(32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "\n",
    "data_next = tf_train.make_one_shot_iterator().get_next()\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "steps = training.N // 32 + 1\n",
    "with tf.Session(config=config) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    try:\n",
    "        for i in tqdm(range(steps)):\n",
    "            out = sess.run(data_next)\n",
    "    except tf.errors.OutOfRangeError:\n",
    "        print(\"Done\") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_local = EquiangularDataset(path, s3=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_local = EquiangularDataset(path, 'val', s3=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_local = training_local.get_tf_dataset(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pygsp.graphs import SphereEquiangular"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g2 = SphereEquiangular(bandwidth=(384, 576), sampling='SOFT')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "glong, glat = np.rad2deg(g2.lon), np.rad2deg(g2.lat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Import the module rather than `from time import time`: the latter rebinds the name `time`\n",
    "# and breaks the earlier cells that call `time.time()` when they are re-run.\n",
    "import time\n",
    "\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "\n",
    "data_next = tf_local.make_one_shot_iterator().get_next()\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "steps = training_local.N // 32 + 1\n",
    "t_start = time.time()\n",
    "with tf.Session(config=config) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    try:\n",
    "        for i in tqdm(range(steps)):\n",
    "            out = sess.run(data_next)\n",
    "            fig = plt.figure(figsize=(20, 10))\n",
    "            ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n",
    "\n",
    "            ax.set_global()\n",
    "            ax.coastlines()\n",
    "\n",
    "            # First channel of the first sample in the batch.\n",
    "            plt.scatter(glong, glat, s=1,\n",
    "                        c=out[0][0,:,:,0], cmap=plt.get_cmap('RdYlBu_r'), alpha=1)\n",
    "            # Only the first batch is fetched and plotted.\n",
    "            break\n",
    "    except tf.errors.OutOfRangeError:\n",
    "        print(\"Done\")\n",
    "print(time.time()-t_start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_s3 = EquiangularDataset(path, s3=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_s3 = training_s3.get_dataset_s3(32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Import the module rather than `from time import time`: the latter rebinds the name `time`\n",
    "# and breaks the earlier cells that call `time.time()` when they are re-run.\n",
    "import time\n",
    "\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "\n",
    "data_next = tf_s3.make_one_shot_iterator().get_next()\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "steps = training_s3.N // 32 + 1\n",
    "t_start = time.time()\n",
    "with tf.Session(config=config) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    try:\n",
    "        for i in tqdm(range(steps)):\n",
    "            # Time each batch fetch separately.\n",
    "            t_begin = time.time()\n",
    "            out = sess.run(data_next)\n",
    "            print('loop time: ', time.time()-t_begin)\n",
    "    except tf.errors.OutOfRangeError:\n",
    "        print(\"Done\")\n",
    "print(time.time()-t_start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.python.lib.io import file_io\n",
    "gstats = file_io.stat('s3://10380-903b2ba14e0d980c25436f9ca5bb29f5/Datasets/Climate/data-2106-02-23-03-1.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gstats.length"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_NAME = 'TestClimate_nopooling_ico_4layers_k4'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(bw1, bw2) = (384, 576)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "params = {'nsides': [5, 5, 4, 3, 2, 1, 0, 0],\n",
    "          'F': [32, 64, 128, 256, 512, 512, 512],#np.max(labels_train).astype(int)+1],\n",
    "          'K': [4]*7,\n",
    "          'batch_norm': [True]*7}\n",
    "params['sampling'] = 'icosahedron'\n",
    "params['dir_name'] = EXP_NAME\n",
    "params['num_feat_in'] = 16 # x_train.shape[-1] # 2*days_pred+3\n",
    "params['conv'] = 'chebyshev5'\n",
    "params['pool'] = 'average'\n",
    "params['activation'] = 'relu'\n",
    "params['statistics'] = None#'mean'\n",
    "params['regularization'] = 0\n",
    "params['dropout'] = 1\n",
    "params['num_epochs'] = 25  # Number of passes through the training data.\n",
    "params['batch_size'] = 32\n",
    "params['scheduler'] = lambda step: tf.train.exponential_decay(1e-3, step, decay_steps=2000, decay_rate=1)\n",
    "#params['optimizer'] = lambda lr: tf.train.GradientDescentOptimizer(lr)\n",
    "# params['optimizer'] = lambda lr: tf.train.AdamOptimizer(lr, beta1=0.9, beta2=0.999, epsilon=1e-8)\n",
    "params['optimizer'] = lambda lr: tf.train.RMSPropOptimizer(lr, decay=0.9, momentum=0.)\n",
    "n_evaluations = 100\n",
    "params['eval_frequency'] = int(params['num_epochs'] * (training.N) / params['batch_size'] / n_evaluations)\n",
    "params['M'] = []\n",
    "params['Fseg'] = 3 # np.max(labels_train).astype(int)+1\n",
    "params['dense'] = True\n",
    "params['tf_dataset'] = training.get_tf_dataset(params['batch_size'])\n",
    "# params['profile'] = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "params = {'nsides': [5, 5, 5, 5, 5],\n",
    "          'F': [32, 64, 128, 256],#np.max(labels_train).astype(int)+1],\n",
    "          'K': [4]*4,\n",
    "          'batch_norm': [True]*4}\n",
    "params['sampling'] = 'icosahedron'\n",
    "params['dir_name'] = EXP_NAME\n",
    "params['num_feat_in'] = 16 # x_train.shape[-1] # 2*days_pred+3\n",
    "params['conv'] = 'chebyshev5'\n",
    "params['pool'] = 'average'\n",
    "params['activation'] = 'relu'\n",
    "params['statistics'] = None#'mean'\n",
    "params['regularization'] = 0\n",
    "params['dropout'] = 1\n",
    "params['num_epochs'] = 30  # Number of passes through the training data.\n",
    "params['batch_size'] = 8\n",
    "params['scheduler'] = lambda step: tf.train.exponential_decay(1e-3, step, decay_steps=2000, decay_rate=1)\n",
    "#params['optimizer'] = lambda lr: tf.train.GradientDescentOptimizer(lr)\n",
    "# params['optimizer'] = lambda lr: tf.train.AdamOptimizer(lr, beta1=0.9, beta2=0.999, epsilon=1e-8)\n",
    "params['optimizer'] = lambda lr: tf.train.RMSPropOptimizer(lr, decay=0.9, momentum=0.)\n",
    "n_evaluations = 60\n",
    "params['eval_frequency'] = int(params['num_epochs'] * (training.N) / params['batch_size'] / n_evaluations)\n",
    "params['M'] = []\n",
    "params['Fseg'] = 3 # np.max(labels_train).astype(int)+1\n",
    "params['dense'] = True\n",
    "params['weighted'] = False\n",
    "params['tf_dataset'] = training.get_tf_dataset(params['batch_size'])\n",
    "# params['profile'] = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_NAME = \"TestClimate_pooling_weight_equi\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cleanup before running again.\n",
    "shutil.rmtree('../../summaries/{}/'.format(EXP_NAME), ignore_errors=True)\n",
    "shutil.rmtree('../../checkpoints/{}/'.format(EXP_NAME), ignore_errors=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "params = {'nsides': [(bw1, bw2), (bw1, bw2),(bw1//4, bw2//4),(bw1//16, bw2//16),(bw1//16, bw2//16)],\n",
    "          'F': [8, 32, 64, 128],#np.max(labels_train).astype(int)+1],\n",
    "          'K': [5]*4,\n",
    "          'batch_norm': [True]*4}\n",
    "params['sampling'] = 'equiangular'\n",
    "params['dir_name'] = EXP_NAME\n",
    "params['num_feat_in'] = 16 # x_train.shape[-1] # 2*days_pred+3\n",
    "params['conv'] = 'chebyshev5'\n",
    "params['pool'] = 'average'\n",
    "params['activation'] = 'relu'\n",
    "params['statistics'] = None#'mean'\n",
    "params['regularization'] = 0\n",
    "params['dropout'] = 1\n",
    "params['num_epochs'] = 25  # Number of passes through the training data.\n",
    "params['batch_size'] = 1\n",
    "params['scheduler'] = lambda step: tf.train.exponential_decay(1e-3, step, decay_steps=2000, decay_rate=1)\n",
    "#params['optimizer'] = lambda lr: tf.train.GradientDescentOptimizer(lr)\n",
    "params['optimizer'] = lambda lr: tf.train.AdamOptimizer(lr, beta1=0.9, beta2=0.999, epsilon=1e-8)\n",
    "# params['optimizer'] = lambda lr: tf.train.RMSPropOptimizer(lr, decay=0.9, momentum=0.)\n",
    "n_evaluations = 100\n",
    "params['eval_frequency'] = int(params['num_epochs'] * (training_local.N) / params['batch_size'] / n_evaluations)\n",
    "params['M'] = []\n",
    "params['Fseg'] = 3 # np.max(labels_train).astype(int)+1\n",
    "params['dense'] = True\n",
    "params['tf_dataset'] = training_local.get_tf_dataset(params['batch_size'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print([12*nside**2 for nside in params['nsides']])\n",
    "model = models.deepsphere(**params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Jiang: 328,339\n",
    "\n",
    "DeepSphere ico deep: 12,926,432\n",
    "\n",
    "DeepSphere ico shallow: 141,624\n",
    "\n",
    "DeepSphere equi: z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"the number of parameters in the model is: {:,}\".format(model.get_nbr_var()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model.fit(training, validation, use_tf_dataset=True, cache='TF', restore=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = model.predict(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probabilities = model.probs(x_test, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n",
    "\n",
    "ax.set_global()\n",
    "# ax.stock_img()\n",
    "ax.coastlines()\n",
    "\n",
    "plt.scatter(icolong, icolat, s=20,\n",
    "            c=predictions[0,:], cmap=plt.get_cmap('RdYlBu_r'), alpha=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n",
    "\n",
    "ax.set_global()\n",
    "# ax.stock_img()\n",
    "ax.coastlines()\n",
    "\n",
    "plt.scatter(icolong, icolat, s=20,\n",
    "            c=labels_test[0,:], cmap=plt.get_cmap('RdYlBu_r'), alpha=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10, 10))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.Orthographic(0, 0))\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "\n",
    "# zmin, zmax = -20, 40\n",
    "\n",
    "plt.scatter(icolong, icolat, s=100,\n",
    "            c=predictions[0,:], cmap=plt.get_cmap('RdYlBu_r'), alpha=1, transform=ccrs.PlateCarree())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, average_precision_score\n",
    "from sklearn.preprocessing import label_binarize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(pred_cls, true_cls, nclass=3):\n",
    "    \"\"\"Per-class, mean per-class and overall accuracy, in percent.\n",
    "\n",
    "    Args:\n",
    "        pred_cls: array of predicted class indices.\n",
    "        true_cls: array of ground-truth class indices, same shape as `pred_cls`.\n",
    "        nclass: number of classes; classes 0..nclass-1 are scored.\n",
    "\n",
    "    Returns:\n",
    "        (per_class, mean_per_class, overall) where `per_class[i]` is the share of\n",
    "        elements whose true class is i that were predicted as i, `mean_per_class`\n",
    "        is the unweighted mean of `per_class`, and `overall` is the share of all\n",
    "        scored elements that were predicted correctly. A class that never occurs\n",
    "        in `true_cls` gives NaN for its entry (0/0), and hence for the mean.\n",
    "    \"\"\"\n",
    "    accu = []\n",
    "    tot_int = 0\n",
    "    tot_cl = 0\n",
    "    # Iterate over `nclass` classes (previously hard-coded to 3).\n",
    "    for i in range(nclass):\n",
    "        intersect = np.sum((pred_cls == i) & (true_cls == i))\n",
    "        thiscls = np.sum(true_cls == i)\n",
    "        accu.append(intersect / thiscls * 100)\n",
    "        tot_int += intersect\n",
    "        tot_cl += thiscls\n",
    "    return np.array(accu), np.mean(accu), tot_int/tot_cl * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def average_precision(score_cls, true_cls, nclass=3):\n",
    "    \"\"\"Per-class average precision and its mean over classes.\n",
    "\n",
    "    Args:\n",
    "        score_cls: array of class scores whose last axis has length `nclass`;\n",
    "            it is flattened to (n_elements, nclass).\n",
    "        true_cls: array of ground-truth class indices; flattened to (n_elements,).\n",
    "        nclass: number of classes; labels 0..nclass-1 are binarised.\n",
    "\n",
    "    Returns:\n",
    "        (AP, mean_AP): `AP` holds one average-precision value per class and\n",
    "        `mean_AP` is their unweighted mean.\n",
    "\n",
    "    Note:\n",
    "        scikit-learn's `label_binarize` returns a single column when given two\n",
    "        classes, so this helper is meant for nclass >= 3.\n",
    "    \"\"\"\n",
    "    # Binarise against all `nclass` labels (previously hard-coded to [0, 1, 2]).\n",
    "    true = label_binarize(true_cls.reshape(-1), classes=list(range(nclass)))\n",
    "    score = score_cls.reshape(-1, nclass)\n",
    "    # `average=None` returns one score per class; it is passed by keyword because\n",
    "    # recent scikit-learn releases no longer accept it positionally.\n",
    "    AP = average_precision_score(true, score, average=None)\n",
    "    return AP, np.mean(AP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy(predictions, labels_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "average_precision(probabilities, labels_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "mAP over the positive classes is 0.7541626"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.utils.class_weight import compute_class_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `classes` and `y` are keyword-only in recent scikit-learn releases, and `classes` is\n",
    "# expected to be an array rather than a list.\n",
    "compute_class_weight('balanced', classes=np.array([0, 1, 2]), y=labels_train.flatten())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
